/* -------------------------------------------------------------------
			 SOURCE CODE: matrix

			      Tom Annau
------------------------------------------------------------------- */

#include "matrix.hh"

#include <math.h>
#include <memory.h>

// Pointer to one row of matrix storage; `value` is an array of these.
typedef double* double_ptr;

/* -------------------------------------------------------------------
		     Constructors and destructors
------------------------------------------------------------------- */

matrix::matrix(int rows, int columns)
{
  // Build a rows x columns matrix with every element set to zero.
  // A 0x0 request quietly yields an empty matrix; any other non-positive
  // dimension is reported on cerr and also yields an empty matrix.
  n_rows = rows;
  n_columns = columns;

  if (rows >= 1 && columns >= 1) {
    value = new double_ptr[rows];
    for (int r = 0; r < rows; ++r) {
      double *row = new double[columns];
      memset(row, 0, columns * sizeof(double));
      value[r] = row;
    }
    return;
  }

  if (rows != 0 || columns != 0)
    cerr << "ERROR: Attempt to create matrix of illegal dimension (" <<
      rows << "x" << columns << ")\n";
  n_rows = 0;
  n_columns = 0;
  value = NULL;
}

matrix::matrix(const matrix& m)
{
  // Deep copy: allocate fresh row buffers and copy every element of m.
  n_columns = m.n_columns;
  n_rows = m.n_rows;

  // An empty source has no storage; start from NULL so the destructor
  // never runs delete[] on an uninitialized pointer.
  value = NULL;

  if (m.value) {
    value = new double_ptr[n_rows];
    for (int i = 0; i < n_rows; i++) {
      value[i] = new double[n_columns];
      memcpy(value[i], m.value[i], n_columns * sizeof(double));
    }
  }
}

matrix::matrix(const vector& v)
{
  // Convert a vector into a single-column matrix (one 1-element row per
  // vector entry). A zero-dimensional vector produces an empty matrix and
  // an error message.
  const int d = v.dimension();

  if (d != 0) {
    n_rows = d;
    n_columns = 1;
    value = new double_ptr[d];
    for (int k = 0; k < d; ++k) {
      value[k] = new double[1];
      value[k][0] = v[k];
    }
    return;
  }

  n_rows = 0;
  n_columns = 0;
  value = NULL;
  cerr << "ERROR: Attempt to construct matrix from zero-dimensional "
    "vector\n";
}

void matrix::resize(int rows, int columns)
{
  // Change the matrix to rows x columns. Existing contents are discarded:
  // after a successful resize every element is zero.
  //
  // Legal requests have both dimensions positive, or both zero (which
  // produces an empty matrix). Anything else is rejected with an error
  // and the matrix is left unchanged. The previous test
  // (rows < 1 && columns < 1) accepted e.g. (0,5), leaving an inconsistent
  // 0x5 matrix with no storage, and (-1,3), which reached new[-1].
  if (rows < 1 || columns < 1) {
    if (rows || columns) {
      cerr << "ERROR: Attempt to resize matrix to illegal dimensions\n";
      return;
    }
  }

  if (rows == n_rows && columns == n_columns) return;

  if (value) {
    for (int i = 0; i < n_rows; i++) delete [] value[i];
    delete [] value;
  }

  if (rows && columns) {
    value = new double_ptr[rows];

    for (int i = 0; i < rows; i++) {
      value[i] = new double[columns];
      memset(value[i], 0, columns * sizeof(double));
    }
  }
  else value = NULL;

  n_rows = rows;
  n_columns = columns;
}

matrix::~matrix(void)
{
  // Free each row buffer, then the array of row pointers itself.
  int row = n_rows;
  while (row-- > 0)
    delete [] value[row];

  delete [] value;
}

/* -------------------------------------------------------------------
			 Assignment operators
------------------------------------------------------------------- */

matrix& matrix::operator = (const matrix& m)
{
  // Element-wise assignment via copy_from; copy_from only copies when
  // check_dims(m) accepts m's dimensions.
  // Self-assignment is a no-op. It is skipped explicitly because copying a
  // row onto itself would call memcpy with identical source and
  // destination, which is undefined behavior.
  if (this != &m)
    copy_from(m);
  return *this;
}

/* -------------------------------------------------------------------
			   Logical operators
------------------------------------------------------------------- */

int matrix::operator ==(const matrix & m) const
{
  // Matrices are equal when they have the same shape and every element
  // compares equal with double '=='. Returns 1 for equal, 0 otherwise.
  if (n_rows == m.n_rows && n_columns == m.n_columns) {
    for (int r = 0; r < n_rows; ++r) {
      const double *mine = value[r];
      const double *theirs = m.value[r];
      for (int c = 0; c < n_columns; ++c)
        if (mine[c] != theirs[c])
          return 0;
    }
    return 1;
  }
  return 0;
}

int matrix::operator !=(const matrix & m) const
{
  // Defined as the negation of operator== so the two never disagree.
  const int same = (*this == m);
  return same ? 0 : 1;
}

/* -------------------------------------------------------------------
			   Type conversions
------------------------------------------------------------------- */

string matrix::to_string(void) const
{
  // Render each row with vector::to_string and join the rows with "; ".
  // An empty matrix yields an empty string.
  //
  // Written in standard C++. The original used GNU's removed
  // named-return-value syntax, and it read loop variable `i` after the
  // loop's scope had ended.
  string s;
  for (int i = 0; i < n_rows; i++) {
    if (i > 0)
      s += "; ";
    s += vector(n_columns, value[i]).to_string();
  }
  return s;
}

string matrix::table_form(void) const
{
  // Render the matrix as a table: elements in a row are separated by TAB,
  // and every row (including the last) is terminated by CR.
  //
  // Written in standard C++. The original used GNU's removed
  // named-return-value syntax, and it read loop variable `j` after the
  // inner loop's scope had ended.
  string s;
  for (int i = 0; i < n_rows; i++) {
    for (int j = 0; j < n_columns; j++) {
      if (j > 0)
        s += TAB;
      s += string(value[i][j]);
    }
    s += CR;
  }
  return s;
}

/* -------------------------------------------------------------------
			Arithmetic operations
------------------------------------------------------------------- */

matrix& matrix::operator += (const matrix& m)
{
  // Element-wise in-place addition. *this is left untouched when
  // check_dims(m) rejects m.
  if (!check_dims(m))
    return *this;

  for (int r = 0; r < n_rows; ++r) {
    double *dst = value[r];
    for (int c = 0; c < n_columns; ++c)
      dst[c] += m[r][c];
  }
  return *this;
}

matrix& matrix::operator -= (const matrix& m)
{
  // Element-wise in-place subtraction. *this is left untouched when
  // check_dims(m) rejects m.
  if (!check_dims(m))
    return *this;

  for (int r = 0; r < n_rows; ++r) {
    double *dst = value[r];
    for (int c = 0; c < n_columns; ++c)
      dst[c] -= m[r][c];
  }
  return *this;
}

matrix matrix::operator * (const matrix& m) const
{
  // Matrix product A*B, where A is *this. The result is always sized
  // rows(A) x columns(B). When columns(A) != rows(B), an error is printed
  // and the freshly constructed (zero-filled) result is returned.
  //
  // Written in standard C++; the original used GNU's removed
  // named-return-value syntax.
  matrix result(n_rows, m.n_columns);

  if (n_columns != m.n_rows) {
    cerr << "ERROR: In matrix multiplication A*B, columns(A) must "
      "equal rows(B)\n";
    return result;
  }

  for (int i = 0; i < n_rows; i++)
    for (int j = 0; j < m.n_columns; j++) {
      double sum = 0;
      for (int k = 0; k < n_columns; k++)
        sum += value[i][k] * m.value[k][j];

      result[i][j] = sum;
    }

  return result;
}

vector matrix::operator * (const vector& v) const
{
  // Matrix/vector product A*x, giving a vector of dimension rows(A).
  // When columns(A) != dimension(x), an error is printed and the freshly
  // constructed result vector(n_rows) is returned unchanged.
  //
  // Written in standard C++; the original used GNU's removed
  // named-return-value syntax.
  vector result(n_rows);

  if (n_columns != v.dimension()) {
    cerr << "ERROR: In matrix/vector multiplication A*x, "
      "columns(A) must equal dimension(x)\n";
    return result;
  }

  for (int i = 0; i < n_rows; i++) {
    double sum = 0;
    for (int k = 0; k < n_columns; k++)
      sum += value[i][k] * v[k];

    result[i] = sum;
  }

  return result;
}

matrix operator * (const vector& v, const matrix& m)
{
  // Vector/matrix product x*A, treating x as a column (dimension(x) x 1).
  // A must be a single row, so the result is dimension(x) x columns(A),
  // with r[i][j] = x[i] * A[0][j]. When A is not a single row, an error is
  // printed and the freshly constructed result is returned unchanged.
  //
  // Written in standard C++; the original used GNU's removed
  // named-return-value syntax.
  matrix r(v.dimension(), m.columns());

  if (m.n_rows != 1) {
    // Trailing newline added to match every other error message here.
    cerr << "ERROR: In vector/matrix multiplication x*A, "
      "A must be a single row\n";
    return r;
  }

  int d = v.dimension();
  for (int i = 0; i < d; i++)
    for (int j = 0; j < m.n_columns; j++)
      r.value[i][j] = v[i] * m.value[0][j];

  return r;
}

matrix outer_product(const vector& v1, const vector& v2)
{
  // Outer product: an n1 x n2 matrix with r[i][j] = v1[i] * v2[j].
  //
  // Written in standard C++; the original used GNU's removed
  // named-return-value syntax.
  int d1 = v1.dimension();
  int d2 = v2.dimension();
  matrix r(d1, d2);

  for (int i = 0; i < d1; i++)
    for (int j = 0; j < d2; j++)
      r.value[i][j] = v1[i] * v2[j];

  return r;
}

/* -------------------------------------------------------------------
			    Miscellaneous
------------------------------------------------------------------- */

void matrix::zero(void)
{
  // Reset every element to all-bits-zero, clearing one row at a time.
  const size_t row_bytes = n_columns * sizeof(double);
  for (int r = 0; r < n_rows; ++r)
    memset(value[r], 0, row_bytes);
}

void matrix::fill_with(double  scalar)
{
  // Overwrite every element with `scalar`; the dimensions are unchanged.
  for (int r = 0; r < n_rows; ++r) {
    double *row = value[r];
    for (double *p = row; p != row + n_columns; ++p)
      *p = scalar;
  }
}

matrix transpose(const matrix& m)
{
  // Return a new columns(m) x rows(m) matrix with t[j][i] = m[i][j].
  //
  // Written in standard C++; the original used GNU's removed
  // named-return-value syntax.
  matrix t(m.n_columns, m.n_rows);

  for (int i = 0; i < m.n_rows; i++)
    for (int j = 0; j < m.n_columns; j++)
      t.value[j][i] = m.value[i][j];

  return t;
}

/* -------------------------------------------------------------------
			   Stream functions
------------------------------------------------------------------- */

ostream& operator < (ostream& out, const matrix& m)
{
  // Binary dump: rows are written one after another, each row as
  // n_columns raw doubles in the host's in-memory representation.
  // No dimensions or header are written; the reader (operator >) must
  // already know the shape.
  //
  // ostream::write takes const char*, so the row pointer is cast
  // explicitly. Passing double* directly does not compile in standard C++.
  for (int i = 0; i < m.n_rows; i++)
    out.write(reinterpret_cast<const char *>(m.value[i]),
              m.n_columns * sizeof(double));
  return out;
}

istream& operator > (istream& in, matrix& m)
{
  // Binary load, the inverse of operator <. It fills the existing
  // rows(m) x columns(m) storage row by row with raw doubles; m must
  // already have the right shape. Read failures are reported only through
  // the stream state of `in`.
  //
  // istream::read takes char*, so the row pointer is cast explicitly.
  for (int i = 0; i < m.n_rows; i++)
    in.read(reinterpret_cast<char *>(m.value[i]),
            m.n_columns * sizeof(double));

  return in;
}

ostream& operator << (ostream& out, const matrix& m)
{
  // Human-readable output: the table_form() rendering followed by one
  // extra CR.
  const string table = m.table_form();
  out << table << CR;
  return out;
}

istream& operator >> (istream& in, matrix& m)
{
  // Read a text matrix into an already-sized m.
  // - Each non-blank line that does not start with '#' is one row, and
  //   its fields are parsed with to_double().
  // - Lines starting with '#' are skipped as comments.
  // - A blank line ends the matrix.
  // Too many rows or the wrong number of columns aborts with an error.
  // Too few rows reports an error, but the rows already read are kept.
  string input_line;
  int r = 0;

  while (!in.eof()) {
    in >> input_line;

    // Stop on a failed read. Previously a failed read left the old line in
    // input_line, so it could be parsed again. A non-EOF stream failure
    // (eofbit never set) also made the loop spin forever.
    if (in.fail())
      break;

    // Check for a blank line before indexing the first character.
    if (input_line.blank())
      break;
    if (input_line[0] == '#')
      continue;

    if (r >= m.n_rows) {
      cerr << "ERROR: Matrix from input stream has too many rows\n";
      return in;
    }

    input_line.strip_spaces();
    int c = input_line.fields();

    if (c != m.n_columns) {
      cerr << "ERROR: Matrix from input stream has wrong number "
        "of columns\n";
      return in;
    }
    for (int i = 0; i < c; i++)
      m.value[r][i] = input_line.field(i).to_double();

    r++;
  }

  if (r != m.n_rows) {
    cerr << "ERROR: Matrix from input stream has too few rows\n";
  }

  return in;
}

/* -------------------------------------------------------------------
			  Protected methods
------------------------------------------------------------------- */

void matrix::copy_from(const matrix& m)
{
  if (check_dims(m))
    for (register int i = 0; i < n_rows; i++)
      memcpy(value[i], m.value[i], n_columns * sizeof(double));
}

void matrix::construct_from_string(string s)
{
  // Build the matrix from text of the form "row0; row1; ...". Each
  // ';'-separated segment is parsed as a vector. The first row fixes the
  // column count, and every later row must match it.
  // A blank string yields an empty (0x0) matrix. On error a message is
  // printed and construction stops; rows not yet parsed stay zero.
  n_rows = 0;
  n_columns = 0;
  value = NULL;

  if (s.blank())
    return;

  int semicolons = s.frequency(";");

  n_rows = semicolons + 1;
  value = new double_ptr[n_rows];

  int l_pos = 0;
  for (int i = 0; i < n_rows; i++) {
    // NOTE(review): depending on the index convention of string::position
    // and string::middle, a trailing ';' may surface as the "inconsistent
    // length" error below rather than here -- confirm against string.hh.
    if (l_pos > s.len()) {
      cerr << "ERROR: When reading matrix from string, found terminal ';'\n";
      return;
    }

    int r_pos = s.position(";", i + 1);
    if (r_pos == NOT_FOUND)
      r_pos = s.len();

    vector v(s.middle(l_pos, r_pos - 1));

    if (n_columns == 0) {
      n_columns = v.dimension();
      if (n_columns == 0) {
        cerr << "ERROR: When reading matrix from string, first row is zero "
          "dimensional\n";
        // No row buffers were allocated yet. Release the pointer array and
        // leave a consistent empty matrix; otherwise the destructor would
        // delete n_rows uninitialized pointers.
        delete [] value;
        value = NULL;
        n_rows = 0;
        return;
      }
      // Allocate all rows up front now that the column count is known, so
      // every row is valid (zero-filled) even if a later row fails.
      for (int k = 0; k < n_rows; k++) {
        value[k] = new double[n_columns];
        memset(value[k], 0, n_columns * sizeof(double));
      }
    }

    if (n_columns != v.dimension()) {
      cerr << "ERROR: When reading matrix from string, row vectors were "
        "of inconsistent length\n";
      return;
    }
    else memcpy(value[i], v.to_double_ptr(), n_columns * sizeof(double));

    l_pos = r_pos + 1;
  }
}